{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.9.0\n",
      "\n",
      "torch 1.3.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# Filter Response Normalization for Network-in-Network CIFAR-10 Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The CNN architecture is based on \n",
    "\n",
    "- Lin, Min, Qiang Chen, and Shuicheng Yan. \"[Network in network](https://arxiv.org/abs/1312.4400).\" arXiv preprint arXiv:1312.4400 (2013).\n",
    "\n",
    "This notebook implements Filter Response Normalization as a drop-in replacement for BatchNorm, based on the paper:\n",
    "\n",
    "- S. Singh and S. Krishnan (2019). **Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks.** https://arxiv.org/abs/1911.09737"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class FilterResponseNormalization(nn.Module):\n",
    "    def __init__(self, num_features, eps=1e-6):\n",
    "        super(FilterResponseNormalization, self).__init__()\n",
    "        \n",
    "        self.register_parameter('beta', \n",
    "                                torch.nn.Parameter(\n",
    "                                        torch.empty([1, num_features, 1, 1]).normal_()))\n",
    "    \n",
    "        self.register_parameter('gamma', \n",
    "                                torch.nn.Parameter(\n",
    "                                        torch.empty([1, num_features, 1, 1]).normal_()))\n",
    "        \n",
    "        self.register_parameter('tau', \n",
    "                                torch.nn.Parameter(\n",
    "                                        torch.empty([1, num_features, 1, 1]).normal_()))\n",
    "        \n",
    "        self.eps = torch.Tensor([eps])\n",
    "\n",
    "    def forward(self, x):\n",
    "        # forward function based on\n",
    "        # https://github.com/gupta-abhay/pytorch-frn/blob/master/frn.py\n",
    "        n, c, h, w = x.size()\n",
    "        \n",
    "        self.eps = self.eps.to(self.tau.device)\n",
    "\n",
    "        nu2 = torch.mean(x.pow(2), (2, 3), keepdims=True)\n",
    "        x = x * torch.rsqrt(nu2 + torch.abs(self.eps))\n",
    "        return torch.max(self.gamma*x + self.beta, self.tau)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Additional Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Subset\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.00005\n",
    "BATCH_SIZE = 256\n",
    "NUM_EPOCHS = 100\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "# Other\n",
    "DEVICE = \"cuda:0\"\n",
    "GRAYSCALE = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([256])\n",
      "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([256])\n",
      "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([256])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### CIFAR-10 Dataset\n",
    "##########################\n",
    "\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "\n",
    "\n",
    "train_indices = torch.arange(0, 49000)\n",
    "valid_indices = torch.arange(49000, 50000)\n",
    "\n",
    "\n",
    "train_and_valid = datasets.CIFAR10(root='data', \n",
    "                                   train=True, \n",
    "                                   transform=transforms.ToTensor(),\n",
    "                                   download=True)\n",
    "\n",
    "train_dataset = Subset(train_and_valid, train_indices)\n",
    "valid_dataset = Subset(train_and_valid, valid_indices)\n",
    "\n",
    "\n",
    "test_dataset = datasets.CIFAR10(root='data', \n",
    "                                train=False, \n",
    "                                transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "#####################################################\n",
    "### Data Loaders\n",
    "#####################################################\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=8,\n",
    "                          shuffle=True)\n",
    "\n",
    "valid_loader = DataLoader(dataset=valid_dataset, \n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=8,\n",
    "                          shuffle=False)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=8,\n",
    "                         shuffle=False)\n",
    "\n",
    "#####################################################\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "\n",
    "for images, labels in test_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "    \n",
    "for images, labels in valid_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Filter Response Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class NiN(nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super(NiN, self).__init__()\n",
    "        self.num_classes = num_classes\n",
    "        self.classifier = nn.Sequential(\n",
    "                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2, bias=False),\n",
    "                FilterResponseNormalization(192),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                FilterResponseNormalization(160),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(160,  96, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                FilterResponseNormalization(96),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2, bias=False),\n",
    "                FilterResponseNormalization(192),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                FilterResponseNormalization(192),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                FilterResponseNormalization(192),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "                FilterResponseNormalization(192),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                FilterResponseNormalization(192),\n",
    "                #nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n",
    "\n",
    "                )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.classifier(x)\n",
    "        logits = x.view(x.size(0), self.num_classes)\n",
    "        probas = torch.softmax(logits, dim=1)\n",
    "        return logits, probas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = NiN(NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/100 | Batch 000/192 | Cost: 2.3138\n",
      "Epoch: 001/100 | Batch 120/192 | Cost: 2.1962\n",
      "Epoch: 001/100 Train Acc.: 22.35% | Validation Acc.: 24.80%\n",
      "Time elapsed: 0.60 min\n",
      "Epoch: 002/100 | Batch 000/192 | Cost: 2.1328\n",
      "Epoch: 002/100 | Batch 120/192 | Cost: 2.0764\n",
      "Epoch: 002/100 Train Acc.: 24.54% | Validation Acc.: 26.70%\n",
      "Time elapsed: 1.21 min\n",
      "Epoch: 003/100 | Batch 000/192 | Cost: 2.0738\n",
      "Epoch: 003/100 | Batch 120/192 | Cost: 1.9631\n",
      "Epoch: 003/100 Train Acc.: 28.22% | Validation Acc.: 32.10%\n",
      "Time elapsed: 1.80 min\n",
      "Epoch: 004/100 | Batch 000/192 | Cost: 2.0147\n",
      "Epoch: 004/100 | Batch 120/192 | Cost: 1.9606\n",
      "Epoch: 004/100 Train Acc.: 29.99% | Validation Acc.: 31.10%\n",
      "Time elapsed: 2.40 min\n",
      "Epoch: 005/100 | Batch 000/192 | Cost: 1.9880\n",
      "Epoch: 005/100 | Batch 120/192 | Cost: 1.9411\n",
      "Epoch: 005/100 Train Acc.: 31.99% | Validation Acc.: 33.40%\n",
      "Time elapsed: 3.00 min\n",
      "Epoch: 006/100 | Batch 000/192 | Cost: 1.9271\n",
      "Epoch: 006/100 | Batch 120/192 | Cost: 1.9257\n",
      "Epoch: 006/100 Train Acc.: 36.23% | Validation Acc.: 36.90%\n",
      "Time elapsed: 3.60 min\n",
      "Epoch: 007/100 | Batch 000/192 | Cost: 1.9154\n",
      "Epoch: 007/100 | Batch 120/192 | Cost: 1.9543\n",
      "Epoch: 007/100 Train Acc.: 37.08% | Validation Acc.: 39.60%\n",
      "Time elapsed: 4.19 min\n",
      "Epoch: 008/100 | Batch 000/192 | Cost: 1.7458\n",
      "Epoch: 008/100 | Batch 120/192 | Cost: 1.7728\n",
      "Epoch: 008/100 Train Acc.: 40.16% | Validation Acc.: 41.90%\n",
      "Time elapsed: 4.79 min\n",
      "Epoch: 009/100 | Batch 000/192 | Cost: 1.7029\n",
      "Epoch: 009/100 | Batch 120/192 | Cost: 1.5784\n",
      "Epoch: 009/100 Train Acc.: 42.81% | Validation Acc.: 42.80%\n",
      "Time elapsed: 5.39 min\n",
      "Epoch: 010/100 | Batch 000/192 | Cost: 1.6909\n",
      "Epoch: 010/100 | Batch 120/192 | Cost: 1.6210\n",
      "Epoch: 010/100 Train Acc.: 44.37% | Validation Acc.: 44.80%\n",
      "Time elapsed: 5.98 min\n",
      "Epoch: 011/100 | Batch 000/192 | Cost: 1.5235\n",
      "Epoch: 011/100 | Batch 120/192 | Cost: 1.7265\n",
      "Epoch: 011/100 Train Acc.: 43.96% | Validation Acc.: 41.70%\n",
      "Time elapsed: 6.58 min\n",
      "Epoch: 012/100 | Batch 000/192 | Cost: 1.6440\n",
      "Epoch: 012/100 | Batch 120/192 | Cost: 1.6264\n",
      "Epoch: 012/100 Train Acc.: 45.15% | Validation Acc.: 47.90%\n",
      "Time elapsed: 7.18 min\n",
      "Epoch: 013/100 | Batch 000/192 | Cost: 1.5785\n",
      "Epoch: 013/100 | Batch 120/192 | Cost: 1.4762\n",
      "Epoch: 013/100 Train Acc.: 47.02% | Validation Acc.: 47.70%\n",
      "Time elapsed: 7.78 min\n",
      "Epoch: 014/100 | Batch 000/192 | Cost: 1.5027\n",
      "Epoch: 014/100 | Batch 120/192 | Cost: 1.4684\n",
      "Epoch: 014/100 Train Acc.: 49.24% | Validation Acc.: 51.60%\n",
      "Time elapsed: 8.37 min\n",
      "Epoch: 015/100 | Batch 000/192 | Cost: 1.4039\n",
      "Epoch: 015/100 | Batch 120/192 | Cost: 1.3735\n",
      "Epoch: 015/100 Train Acc.: 50.82% | Validation Acc.: 49.60%\n",
      "Time elapsed: 8.97 min\n",
      "Epoch: 016/100 | Batch 000/192 | Cost: 1.3997\n",
      "Epoch: 016/100 | Batch 120/192 | Cost: 1.2700\n",
      "Epoch: 016/100 Train Acc.: 53.31% | Validation Acc.: 52.50%\n",
      "Time elapsed: 9.57 min\n",
      "Epoch: 017/100 | Batch 000/192 | Cost: 1.2491\n",
      "Epoch: 017/100 | Batch 120/192 | Cost: 1.3126\n",
      "Epoch: 017/100 Train Acc.: 53.79% | Validation Acc.: 54.50%\n",
      "Time elapsed: 10.17 min\n",
      "Epoch: 018/100 | Batch 000/192 | Cost: 1.1533\n",
      "Epoch: 018/100 | Batch 120/192 | Cost: 1.2444\n",
      "Epoch: 018/100 Train Acc.: 54.08% | Validation Acc.: 54.70%\n",
      "Time elapsed: 10.76 min\n",
      "Epoch: 019/100 | Batch 000/192 | Cost: 1.3304\n",
      "Epoch: 019/100 | Batch 120/192 | Cost: 1.2897\n",
      "Epoch: 019/100 Train Acc.: 50.59% | Validation Acc.: 49.70%\n",
      "Time elapsed: 11.36 min\n",
      "Epoch: 020/100 | Batch 000/192 | Cost: 1.3537\n",
      "Epoch: 020/100 | Batch 120/192 | Cost: 1.3673\n",
      "Epoch: 020/100 Train Acc.: 56.56% | Validation Acc.: 58.20%\n",
      "Time elapsed: 11.96 min\n",
      "Epoch: 021/100 | Batch 000/192 | Cost: 1.2397\n",
      "Epoch: 021/100 | Batch 120/192 | Cost: 1.2370\n",
      "Epoch: 021/100 Train Acc.: 56.75% | Validation Acc.: 55.90%\n",
      "Time elapsed: 12.56 min\n",
      "Epoch: 022/100 | Batch 000/192 | Cost: 1.1461\n",
      "Epoch: 022/100 | Batch 120/192 | Cost: 1.2558\n",
      "Epoch: 022/100 Train Acc.: 57.97% | Validation Acc.: 56.80%\n",
      "Time elapsed: 13.15 min\n",
      "Epoch: 023/100 | Batch 000/192 | Cost: 1.2667\n",
      "Epoch: 023/100 | Batch 120/192 | Cost: 1.1730\n",
      "Epoch: 023/100 Train Acc.: 57.39% | Validation Acc.: 58.30%\n",
      "Time elapsed: 13.75 min\n",
      "Epoch: 024/100 | Batch 000/192 | Cost: 1.1891\n",
      "Epoch: 024/100 | Batch 120/192 | Cost: 1.1483\n",
      "Epoch: 024/100 Train Acc.: 58.32% | Validation Acc.: 59.20%\n",
      "Time elapsed: 14.35 min\n",
      "Epoch: 025/100 | Batch 000/192 | Cost: 1.2057\n",
      "Epoch: 025/100 | Batch 120/192 | Cost: 1.1603\n",
      "Epoch: 025/100 Train Acc.: 59.83% | Validation Acc.: 58.80%\n",
      "Time elapsed: 14.94 min\n",
      "Epoch: 026/100 | Batch 000/192 | Cost: 1.0880\n",
      "Epoch: 026/100 | Batch 120/192 | Cost: 1.2100\n",
      "Epoch: 026/100 Train Acc.: 60.84% | Validation Acc.: 59.70%\n",
      "Time elapsed: 15.54 min\n",
      "Epoch: 027/100 | Batch 000/192 | Cost: 1.0250\n",
      "Epoch: 027/100 | Batch 120/192 | Cost: 1.2468\n",
      "Epoch: 027/100 Train Acc.: 60.00% | Validation Acc.: 58.70%\n",
      "Time elapsed: 16.14 min\n",
      "Epoch: 028/100 | Batch 000/192 | Cost: 1.1053\n",
      "Epoch: 028/100 | Batch 120/192 | Cost: 1.1635\n",
      "Epoch: 028/100 Train Acc.: 60.78% | Validation Acc.: 58.50%\n",
      "Time elapsed: 16.74 min\n",
      "Epoch: 029/100 | Batch 000/192 | Cost: 1.0719\n",
      "Epoch: 029/100 | Batch 120/192 | Cost: 1.1288\n",
      "Epoch: 029/100 Train Acc.: 61.81% | Validation Acc.: 60.30%\n",
      "Time elapsed: 17.33 min\n",
      "Epoch: 030/100 | Batch 000/192 | Cost: 1.1443\n",
      "Epoch: 030/100 | Batch 120/192 | Cost: 1.0943\n",
      "Epoch: 030/100 Train Acc.: 60.08% | Validation Acc.: 56.70%\n",
      "Time elapsed: 17.93 min\n",
      "Epoch: 031/100 | Batch 000/192 | Cost: 1.0960\n",
      "Epoch: 031/100 | Batch 120/192 | Cost: 1.1090\n",
      "Epoch: 031/100 Train Acc.: 62.24% | Validation Acc.: 61.20%\n",
      "Time elapsed: 18.52 min\n",
      "Epoch: 032/100 | Batch 000/192 | Cost: 1.0046\n",
      "Epoch: 032/100 | Batch 120/192 | Cost: 1.2110\n",
      "Epoch: 032/100 Train Acc.: 60.71% | Validation Acc.: 62.30%\n",
      "Time elapsed: 19.12 min\n",
      "Epoch: 033/100 | Batch 000/192 | Cost: 1.0136\n",
      "Epoch: 033/100 | Batch 120/192 | Cost: 1.0387\n",
      "Epoch: 033/100 Train Acc.: 62.54% | Validation Acc.: 60.90%\n",
      "Time elapsed: 19.72 min\n",
      "Epoch: 034/100 | Batch 000/192 | Cost: 1.1746\n",
      "Epoch: 034/100 | Batch 120/192 | Cost: 1.1025\n",
      "Epoch: 034/100 Train Acc.: 64.18% | Validation Acc.: 64.10%\n",
      "Time elapsed: 20.32 min\n",
      "Epoch: 035/100 | Batch 000/192 | Cost: 0.9348\n",
      "Epoch: 035/100 | Batch 120/192 | Cost: 1.0161\n",
      "Epoch: 035/100 Train Acc.: 64.49% | Validation Acc.: 62.80%\n",
      "Time elapsed: 20.91 min\n",
      "Epoch: 036/100 | Batch 000/192 | Cost: 0.9798\n",
      "Epoch: 036/100 | Batch 120/192 | Cost: 1.0668\n",
      "Epoch: 036/100 Train Acc.: 63.90% | Validation Acc.: 63.00%\n",
      "Time elapsed: 21.51 min\n",
      "Epoch: 037/100 | Batch 000/192 | Cost: 0.9383\n",
      "Epoch: 037/100 | Batch 120/192 | Cost: 1.0818\n",
      "Epoch: 037/100 Train Acc.: 62.82% | Validation Acc.: 61.30%\n",
      "Time elapsed: 22.11 min\n",
      "Epoch: 038/100 | Batch 000/192 | Cost: 0.9943\n",
      "Epoch: 038/100 | Batch 120/192 | Cost: 0.9690\n",
      "Epoch: 038/100 Train Acc.: 64.09% | Validation Acc.: 62.90%\n",
      "Time elapsed: 22.71 min\n",
      "Epoch: 039/100 | Batch 000/192 | Cost: 1.0250\n",
      "Epoch: 039/100 | Batch 120/192 | Cost: 1.0926\n",
      "Epoch: 039/100 Train Acc.: 64.78% | Validation Acc.: 63.20%\n",
      "Time elapsed: 23.31 min\n",
      "Epoch: 040/100 | Batch 000/192 | Cost: 0.9488\n",
      "Epoch: 040/100 | Batch 120/192 | Cost: 1.0708\n",
      "Epoch: 040/100 Train Acc.: 64.24% | Validation Acc.: 61.30%\n",
      "Time elapsed: 23.91 min\n",
      "Epoch: 041/100 | Batch 000/192 | Cost: 1.0845\n",
      "Epoch: 041/100 | Batch 120/192 | Cost: 1.0342\n",
      "Epoch: 041/100 Train Acc.: 65.33% | Validation Acc.: 65.30%\n",
      "Time elapsed: 24.50 min\n",
      "Epoch: 042/100 | Batch 000/192 | Cost: 0.9705\n",
      "Epoch: 042/100 | Batch 120/192 | Cost: 0.9504\n",
      "Epoch: 042/100 Train Acc.: 65.32% | Validation Acc.: 63.70%\n",
      "Time elapsed: 25.10 min\n",
      "Epoch: 043/100 | Batch 000/192 | Cost: 0.9865\n",
      "Epoch: 043/100 | Batch 120/192 | Cost: 0.9781\n",
      "Epoch: 043/100 Train Acc.: 65.94% | Validation Acc.: 67.00%\n",
      "Time elapsed: 25.70 min\n",
      "Epoch: 044/100 | Batch 000/192 | Cost: 0.9381\n",
      "Epoch: 044/100 | Batch 120/192 | Cost: 0.8644\n",
      "Epoch: 044/100 Train Acc.: 65.39% | Validation Acc.: 63.10%\n",
      "Time elapsed: 26.30 min\n",
      "Epoch: 045/100 | Batch 000/192 | Cost: 0.9364\n",
      "Epoch: 045/100 | Batch 120/192 | Cost: 0.9044\n",
      "Epoch: 045/100 Train Acc.: 66.59% | Validation Acc.: 63.90%\n",
      "Time elapsed: 26.90 min\n",
      "Epoch: 046/100 | Batch 000/192 | Cost: 1.0168\n",
      "Epoch: 046/100 | Batch 120/192 | Cost: 1.0457\n",
      "Epoch: 046/100 Train Acc.: 67.00% | Validation Acc.: 64.60%\n",
      "Time elapsed: 27.49 min\n",
      "Epoch: 047/100 | Batch 000/192 | Cost: 0.9355\n",
      "Epoch: 047/100 | Batch 120/192 | Cost: 0.8651\n",
      "Epoch: 047/100 Train Acc.: 66.29% | Validation Acc.: 63.80%\n",
      "Time elapsed: 28.10 min\n",
      "Epoch: 048/100 | Batch 000/192 | Cost: 0.8658\n",
      "Epoch: 048/100 | Batch 120/192 | Cost: 0.9858\n",
      "Epoch: 048/100 Train Acc.: 67.62% | Validation Acc.: 66.30%\n",
      "Time elapsed: 28.69 min\n",
      "Epoch: 049/100 | Batch 000/192 | Cost: 0.9450\n",
      "Epoch: 049/100 | Batch 120/192 | Cost: 0.8909\n",
      "Epoch: 049/100 Train Acc.: 65.02% | Validation Acc.: 62.30%\n",
      "Time elapsed: 29.29 min\n",
      "Epoch: 050/100 | Batch 000/192 | Cost: 1.0726\n",
      "Epoch: 050/100 | Batch 120/192 | Cost: 0.8521\n",
      "Epoch: 050/100 Train Acc.: 65.28% | Validation Acc.: 61.80%\n",
      "Time elapsed: 29.89 min\n",
      "Epoch: 051/100 | Batch 000/192 | Cost: 1.0874\n",
      "Epoch: 051/100 | Batch 120/192 | Cost: 0.8756\n",
      "Epoch: 051/100 Train Acc.: 67.28% | Validation Acc.: 66.20%\n",
      "Time elapsed: 30.49 min\n",
      "Epoch: 052/100 | Batch 000/192 | Cost: 0.8616\n",
      "Epoch: 052/100 | Batch 120/192 | Cost: 0.8323\n",
      "Epoch: 052/100 Train Acc.: 67.98% | Validation Acc.: 66.70%\n",
      "Time elapsed: 31.09 min\n",
      "Epoch: 053/100 | Batch 000/192 | Cost: 0.8961\n",
      "Epoch: 053/100 | Batch 120/192 | Cost: 0.8221\n",
      "Epoch: 053/100 Train Acc.: 67.47% | Validation Acc.: 67.00%\n",
      "Time elapsed: 31.68 min\n",
      "Epoch: 054/100 | Batch 000/192 | Cost: 1.0299\n",
      "Epoch: 054/100 | Batch 120/192 | Cost: 0.9036\n",
      "Epoch: 054/100 Train Acc.: 68.66% | Validation Acc.: 66.30%\n",
      "Time elapsed: 32.28 min\n",
      "Epoch: 055/100 | Batch 000/192 | Cost: 0.9415\n",
      "Epoch: 055/100 | Batch 120/192 | Cost: 0.9244\n",
      "Epoch: 055/100 Train Acc.: 69.47% | Validation Acc.: 67.60%\n",
      "Time elapsed: 32.88 min\n",
      "Epoch: 056/100 | Batch 000/192 | Cost: 0.9208\n",
      "Epoch: 056/100 | Batch 120/192 | Cost: 0.8804\n",
      "Epoch: 056/100 Train Acc.: 68.89% | Validation Acc.: 65.20%\n",
      "Time elapsed: 33.48 min\n",
      "Epoch: 057/100 | Batch 000/192 | Cost: 0.8637\n",
      "Epoch: 057/100 | Batch 120/192 | Cost: 0.8408\n",
      "Epoch: 057/100 Train Acc.: 69.76% | Validation Acc.: 65.50%\n",
      "Time elapsed: 34.08 min\n",
      "Epoch: 058/100 | Batch 000/192 | Cost: 0.8583\n",
      "Epoch: 058/100 | Batch 120/192 | Cost: 0.8561\n",
      "Epoch: 058/100 Train Acc.: 68.19% | Validation Acc.: 67.10%\n",
      "Time elapsed: 34.68 min\n",
      "Epoch: 059/100 | Batch 000/192 | Cost: 1.0149\n",
      "Epoch: 059/100 | Batch 120/192 | Cost: 0.8976\n",
      "Epoch: 059/100 Train Acc.: 69.50% | Validation Acc.: 67.40%\n",
      "Time elapsed: 35.28 min\n",
      "Epoch: 060/100 | Batch 000/192 | Cost: 0.8902\n",
      "Epoch: 060/100 | Batch 120/192 | Cost: 0.9281\n",
      "Epoch: 060/100 Train Acc.: 70.08% | Validation Acc.: 67.60%\n",
      "Time elapsed: 35.87 min\n",
      "Epoch: 061/100 | Batch 000/192 | Cost: 0.8918\n",
      "Epoch: 061/100 | Batch 120/192 | Cost: 0.8235\n",
      "Epoch: 061/100 Train Acc.: 71.08% | Validation Acc.: 66.80%\n",
      "Time elapsed: 36.47 min\n",
      "Epoch: 062/100 | Batch 000/192 | Cost: 0.8430\n",
      "Epoch: 062/100 | Batch 120/192 | Cost: 0.9288\n",
      "Epoch: 062/100 Train Acc.: 69.94% | Validation Acc.: 67.90%\n",
      "Time elapsed: 37.07 min\n",
      "Epoch: 063/100 | Batch 000/192 | Cost: 0.8885\n",
      "Epoch: 063/100 | Batch 120/192 | Cost: 0.9188\n",
      "Epoch: 063/100 Train Acc.: 69.76% | Validation Acc.: 67.50%\n",
      "Time elapsed: 37.66 min\n",
      "Epoch: 064/100 | Batch 000/192 | Cost: 0.7734\n",
      "Epoch: 064/100 | Batch 120/192 | Cost: 0.8158\n",
      "Epoch: 064/100 Train Acc.: 68.91% | Validation Acc.: 67.60%\n",
      "Time elapsed: 38.26 min\n",
      "Epoch: 065/100 | Batch 000/192 | Cost: 0.8541\n",
      "Epoch: 065/100 | Batch 120/192 | Cost: 0.8135\n",
      "Epoch: 065/100 Train Acc.: 70.94% | Validation Acc.: 67.80%\n",
      "Time elapsed: 38.86 min\n",
      "Epoch: 066/100 | Batch 000/192 | Cost: 0.8057\n",
      "Epoch: 066/100 | Batch 120/192 | Cost: 0.9248\n",
      "Epoch: 066/100 Train Acc.: 71.24% | Validation Acc.: 68.10%\n",
      "Time elapsed: 39.46 min\n",
      "Epoch: 067/100 | Batch 000/192 | Cost: 0.7940\n",
      "Epoch: 067/100 | Batch 120/192 | Cost: 0.8346\n",
      "Epoch: 067/100 Train Acc.: 71.26% | Validation Acc.: 68.00%\n",
      "Time elapsed: 40.05 min\n",
      "Epoch: 068/100 | Batch 000/192 | Cost: 0.7323\n",
      "Epoch: 068/100 | Batch 120/192 | Cost: 0.8404\n",
      "Epoch: 068/100 Train Acc.: 71.13% | Validation Acc.: 69.20%\n",
      "Time elapsed: 40.65 min\n",
      "Epoch: 069/100 | Batch 000/192 | Cost: 0.8699\n",
      "Epoch: 069/100 | Batch 120/192 | Cost: 0.9162\n",
      "Epoch: 069/100 Train Acc.: 68.77% | Validation Acc.: 65.40%\n",
      "Time elapsed: 41.25 min\n",
      "Epoch: 070/100 | Batch 000/192 | Cost: 0.8281\n",
      "Epoch: 070/100 | Batch 120/192 | Cost: 0.7858\n",
      "Epoch: 070/100 Train Acc.: 72.29% | Validation Acc.: 68.50%\n",
      "Time elapsed: 41.85 min\n",
      "Epoch: 071/100 | Batch 000/192 | Cost: 0.8033\n",
      "Epoch: 071/100 | Batch 120/192 | Cost: 0.7622\n",
      "Epoch: 071/100 Train Acc.: 70.89% | Validation Acc.: 68.20%\n",
      "Time elapsed: 42.44 min\n",
      "Epoch: 072/100 | Batch 000/192 | Cost: 0.8036\n",
      "Epoch: 072/100 | Batch 120/192 | Cost: 0.7706\n",
      "Epoch: 072/100 Train Acc.: 71.73% | Validation Acc.: 69.30%\n",
      "Time elapsed: 43.04 min\n",
      "Epoch: 073/100 | Batch 000/192 | Cost: 0.8105\n",
      "Epoch: 073/100 | Batch 120/192 | Cost: 0.7701\n",
      "Epoch: 073/100 Train Acc.: 72.28% | Validation Acc.: 68.50%\n",
      "Time elapsed: 43.64 min\n",
      "Epoch: 074/100 | Batch 000/192 | Cost: 0.8407\n",
      "Epoch: 074/100 | Batch 120/192 | Cost: 0.7073\n",
      "Epoch: 074/100 Train Acc.: 71.31% | Validation Acc.: 66.80%\n",
      "Time elapsed: 44.23 min\n",
      "Epoch: 075/100 | Batch 000/192 | Cost: 0.8167\n",
      "Epoch: 075/100 | Batch 120/192 | Cost: 0.8992\n",
      "Epoch: 075/100 Train Acc.: 70.88% | Validation Acc.: 67.70%\n",
      "Time elapsed: 44.83 min\n",
      "Epoch: 076/100 | Batch 000/192 | Cost: 0.8213\n",
      "Epoch: 076/100 | Batch 120/192 | Cost: 0.7992\n",
      "Epoch: 076/100 Train Acc.: 73.29% | Validation Acc.: 69.70%\n",
      "Time elapsed: 45.43 min\n",
      "Epoch: 077/100 | Batch 000/192 | Cost: 0.7490\n",
      "Epoch: 077/100 | Batch 120/192 | Cost: 0.8266\n",
      "Epoch: 077/100 Train Acc.: 71.30% | Validation Acc.: 69.50%\n",
      "Time elapsed: 46.03 min\n",
      "Epoch: 078/100 | Batch 000/192 | Cost: 0.8243\n",
      "Epoch: 078/100 | Batch 120/192 | Cost: 0.8135\n",
      "Epoch: 078/100 Train Acc.: 72.05% | Validation Acc.: 69.60%\n",
      "Time elapsed: 46.62 min\n",
      "Epoch: 079/100 | Batch 000/192 | Cost: 0.8013\n",
      "Epoch: 079/100 | Batch 120/192 | Cost: 0.7685\n",
      "Epoch: 079/100 Train Acc.: 71.57% | Validation Acc.: 69.10%\n",
      "Time elapsed: 47.22 min\n",
      "Epoch: 080/100 | Batch 000/192 | Cost: 0.7319\n",
      "Epoch: 080/100 | Batch 120/192 | Cost: 0.8254\n",
      "Epoch: 080/100 Train Acc.: 74.19% | Validation Acc.: 71.50%\n",
      "Time elapsed: 47.82 min\n",
      "Epoch: 081/100 | Batch 000/192 | Cost: 0.6880\n",
      "Epoch: 081/100 | Batch 120/192 | Cost: 0.8099\n",
      "Epoch: 081/100 Train Acc.: 73.16% | Validation Acc.: 70.70%\n",
      "Time elapsed: 48.42 min\n",
      "Epoch: 082/100 | Batch 000/192 | Cost: 0.6497\n",
      "Epoch: 082/100 | Batch 120/192 | Cost: 0.8208\n",
      "Epoch: 082/100 Train Acc.: 72.91% | Validation Acc.: 71.60%\n",
      "Time elapsed: 49.01 min\n",
      "Epoch: 083/100 | Batch 000/192 | Cost: 0.6912\n",
      "Epoch: 083/100 | Batch 120/192 | Cost: 0.7313\n",
      "Epoch: 083/100 Train Acc.: 74.04% | Validation Acc.: 70.10%\n",
      "Time elapsed: 49.61 min\n",
      "Epoch: 084/100 | Batch 000/192 | Cost: 0.7115\n",
      "Epoch: 084/100 | Batch 120/192 | Cost: 0.7145\n",
      "Epoch: 084/100 Train Acc.: 74.19% | Validation Acc.: 71.40%\n",
      "Time elapsed: 50.21 min\n",
      "Epoch: 085/100 | Batch 000/192 | Cost: 0.7889\n",
      "Epoch: 085/100 | Batch 120/192 | Cost: 0.6902\n",
      "Epoch: 085/100 Train Acc.: 73.75% | Validation Acc.: 69.70%\n",
      "Time elapsed: 50.80 min\n",
      "Epoch: 086/100 | Batch 000/192 | Cost: 0.7790\n",
      "Epoch: 086/100 | Batch 120/192 | Cost: 0.7807\n",
      "Epoch: 086/100 Train Acc.: 74.25% | Validation Acc.: 70.50%\n",
      "Time elapsed: 51.40 min\n",
      "Epoch: 087/100 | Batch 000/192 | Cost: 0.6930\n",
      "Epoch: 087/100 | Batch 120/192 | Cost: 0.7695\n",
      "Epoch: 087/100 Train Acc.: 73.66% | Validation Acc.: 71.30%\n",
      "Time elapsed: 52.00 min\n",
      "Epoch: 088/100 | Batch 000/192 | Cost: 0.7779\n",
      "Epoch: 088/100 | Batch 120/192 | Cost: 0.7066\n",
      "Epoch: 088/100 Train Acc.: 74.70% | Validation Acc.: 70.70%\n",
      "Time elapsed: 52.60 min\n",
      "Epoch: 089/100 | Batch 000/192 | Cost: 0.7385\n",
      "Epoch: 089/100 | Batch 120/192 | Cost: 0.7996\n",
      "Epoch: 089/100 Train Acc.: 73.74% | Validation Acc.: 69.80%\n",
      "Time elapsed: 53.20 min\n",
      "Epoch: 090/100 | Batch 000/192 | Cost: 0.7230\n",
      "Epoch: 090/100 | Batch 120/192 | Cost: 0.7889\n",
      "Epoch: 090/100 Train Acc.: 75.10% | Validation Acc.: 71.70%\n",
      "Time elapsed: 53.80 min\n",
      "Epoch: 091/100 | Batch 000/192 | Cost: 0.7388\n",
      "Epoch: 091/100 | Batch 120/192 | Cost: 0.7569\n",
      "Epoch: 091/100 Train Acc.: 75.22% | Validation Acc.: 71.90%\n",
      "Time elapsed: 54.39 min\n",
      "Epoch: 092/100 | Batch 000/192 | Cost: 0.6327\n",
      "Epoch: 092/100 | Batch 120/192 | Cost: 0.6937\n",
      "Epoch: 092/100 Train Acc.: 75.02% | Validation Acc.: 72.40%\n",
      "Time elapsed: 54.99 min\n",
      "Epoch: 093/100 | Batch 000/192 | Cost: 0.7830\n",
      "Epoch: 093/100 | Batch 120/192 | Cost: 0.6699\n",
      "Epoch: 093/100 Train Acc.: 75.35% | Validation Acc.: 71.20%\n",
      "Time elapsed: 55.59 min\n",
      "Epoch: 094/100 | Batch 000/192 | Cost: 0.7633\n",
      "Epoch: 094/100 | Batch 120/192 | Cost: 0.7601\n",
      "Epoch: 094/100 Train Acc.: 74.38% | Validation Acc.: 71.20%\n",
      "Time elapsed: 56.19 min\n",
      "Epoch: 095/100 | Batch 000/192 | Cost: 0.6946\n",
      "Epoch: 095/100 | Batch 120/192 | Cost: 0.8123\n",
      "Epoch: 095/100 Train Acc.: 75.07% | Validation Acc.: 70.90%\n",
      "Time elapsed: 56.79 min\n",
      "Epoch: 096/100 | Batch 000/192 | Cost: 0.6740\n",
      "Epoch: 096/100 | Batch 120/192 | Cost: 0.6667\n",
      "Epoch: 096/100 Train Acc.: 75.40% | Validation Acc.: 73.30%\n",
      "Time elapsed: 57.38 min\n",
      "Epoch: 097/100 | Batch 000/192 | Cost: 0.6683\n",
      "Epoch: 097/100 | Batch 120/192 | Cost: 0.6792\n",
      "Epoch: 097/100 Train Acc.: 73.43% | Validation Acc.: 68.00%\n",
      "Time elapsed: 57.98 min\n",
      "Epoch: 098/100 | Batch 000/192 | Cost: 0.8212\n",
      "Epoch: 098/100 | Batch 120/192 | Cost: 0.8133\n",
      "Epoch: 098/100 Train Acc.: 75.72% | Validation Acc.: 72.80%\n",
      "Time elapsed: 58.58 min\n",
      "Epoch: 099/100 | Batch 000/192 | Cost: 0.7529\n",
      "Epoch: 099/100 | Batch 120/192 | Cost: 0.7609\n",
      "Epoch: 099/100 Train Acc.: 73.84% | Validation Acc.: 70.60%\n",
      "Time elapsed: 59.18 min\n",
      "Epoch: 100/100 | Batch 000/192 | Cost: 0.8095\n",
      "Epoch: 100/100 | Batch 120/192 | Cost: 0.7173\n",
      "Epoch: 100/100 Train Acc.: 75.09% | Validation Acc.: 72.10%\n",
      "Time elapsed: 59.78 min\n",
      "Total Training Time: 59.78 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):\n",
    "            \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "    \n",
    "        ### PREPARE MINIBATCH\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 120:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n",
    "                   f' Cost: {cost:.4f}')\n",
    "\n",
    "    # no need to build the computation graph for backprop when computing accuracy\n",
    "    with torch.set_grad_enabled(False):\n",
    "        train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n",
    "        valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n",
    "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n",
    "              f' | Validation Acc.: {valid_acc:.2f}%')\n",
    "        \n",
    "    elapsed = (time.time() - start_time)/60\n",
    "    print(f'Time elapsed: {elapsed:.2f} min')\n",
    "  \n",
    "elapsed = (time.time() - start_time)/60\n",
    "print(f'Total Training Time: {elapsed:.2f} min')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch Normalization (for comparison)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class NiN(nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super(NiN, self).__init__()\n",
    "        self.num_classes = num_classes\n",
    "        self.classifier = nn.Sequential(\n",
    "                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2, bias=False),\n",
    "                nn.BatchNorm2d(192),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                nn.BatchNorm2d(160),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(160,  96, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                nn.BatchNorm2d(96),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2, bias=False),\n",
    "                nn.BatchNorm2d(192),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                nn.BatchNorm2d(192),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                nn.BatchNorm2d(192),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "                nn.BatchNorm2d(192),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0, bias=False),\n",
    "                nn.BatchNorm2d(192),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n",
    "\n",
    "                )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.classifier(x)\n",
    "        logits = x.view(x.size(0), self.num_classes)\n",
    "        probas = torch.softmax(logits, dim=1)\n",
    "        return logits, probas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/100 | Batch 000/192 | Cost: 0.6825\n",
      "Epoch: 001/100 | Batch 120/192 | Cost: 0.7027\n",
      "Epoch: 001/100 Train Acc.: 76.08% | Validation Acc.: 72.60%\n",
      "Time elapsed: 0.60 min\n",
      "Epoch: 002/100 | Batch 000/192 | Cost: 0.6714\n",
      "Epoch: 002/100 | Batch 120/192 | Cost: 0.6445\n",
      "Epoch: 002/100 Train Acc.: 76.35% | Validation Acc.: 73.40%\n",
      "Time elapsed: 1.20 min\n",
      "Epoch: 003/100 | Batch 000/192 | Cost: 0.7066\n",
      "Epoch: 003/100 | Batch 120/192 | Cost: 0.6379\n",
      "Epoch: 003/100 Train Acc.: 76.63% | Validation Acc.: 73.80%\n",
      "Time elapsed: 1.80 min\n",
      "Epoch: 004/100 | Batch 000/192 | Cost: 0.7371\n",
      "Epoch: 004/100 | Batch 120/192 | Cost: 0.7181\n",
      "Epoch: 004/100 Train Acc.: 75.50% | Validation Acc.: 70.80%\n",
      "Time elapsed: 2.40 min\n",
      "Epoch: 005/100 | Batch 000/192 | Cost: 0.7634\n",
      "Epoch: 005/100 | Batch 120/192 | Cost: 0.6887\n",
      "Epoch: 005/100 Train Acc.: 77.08% | Validation Acc.: 73.50%\n",
      "Time elapsed: 3.00 min\n",
      "Epoch: 006/100 | Batch 000/192 | Cost: 0.6509\n",
      "Epoch: 006/100 | Batch 120/192 | Cost: 0.7140\n",
      "Epoch: 006/100 Train Acc.: 76.52% | Validation Acc.: 73.70%\n",
      "Time elapsed: 3.60 min\n",
      "Epoch: 007/100 | Batch 000/192 | Cost: 0.7971\n",
      "Epoch: 007/100 | Batch 120/192 | Cost: 0.6890\n",
      "Epoch: 007/100 Train Acc.: 76.42% | Validation Acc.: 72.60%\n",
      "Time elapsed: 4.20 min\n",
      "Epoch: 008/100 | Batch 000/192 | Cost: 0.6466\n",
      "Epoch: 008/100 | Batch 120/192 | Cost: 0.5831\n",
      "Epoch: 008/100 Train Acc.: 76.12% | Validation Acc.: 72.60%\n",
      "Time elapsed: 4.80 min\n",
      "Epoch: 009/100 | Batch 000/192 | Cost: 0.6249\n",
      "Epoch: 009/100 | Batch 120/192 | Cost: 0.7722\n",
      "Epoch: 009/100 Train Acc.: 76.25% | Validation Acc.: 73.00%\n",
      "Time elapsed: 5.40 min\n",
      "Epoch: 010/100 | Batch 000/192 | Cost: 0.6890\n",
      "Epoch: 010/100 | Batch 120/192 | Cost: 0.6893\n",
      "Epoch: 010/100 Train Acc.: 76.44% | Validation Acc.: 72.30%\n",
      "Time elapsed: 6.00 min\n",
      "Epoch: 011/100 | Batch 000/192 | Cost: 0.5889\n",
      "Epoch: 011/100 | Batch 120/192 | Cost: 0.6722\n",
      "Epoch: 011/100 Train Acc.: 75.29% | Validation Acc.: 71.20%\n",
      "Time elapsed: 6.60 min\n",
      "Epoch: 012/100 | Batch 000/192 | Cost: 0.7587\n",
      "Epoch: 012/100 | Batch 120/192 | Cost: 0.6418\n",
      "Epoch: 012/100 Train Acc.: 74.49% | Validation Acc.: 70.20%\n",
      "Time elapsed: 7.20 min\n",
      "Epoch: 013/100 | Batch 000/192 | Cost: 0.7850\n",
      "Epoch: 013/100 | Batch 120/192 | Cost: 0.6407\n",
      "Epoch: 013/100 Train Acc.: 76.05% | Validation Acc.: 71.00%\n",
      "Time elapsed: 7.80 min\n",
      "Epoch: 014/100 | Batch 000/192 | Cost: 0.8101\n",
      "Epoch: 014/100 | Batch 120/192 | Cost: 0.6613\n",
      "Epoch: 014/100 Train Acc.: 78.08% | Validation Acc.: 75.60%\n",
      "Time elapsed: 8.40 min\n",
      "Epoch: 015/100 | Batch 000/192 | Cost: 0.6271\n",
      "Epoch: 015/100 | Batch 120/192 | Cost: 0.6263\n",
      "Epoch: 015/100 Train Acc.: 77.98% | Validation Acc.: 73.80%\n",
      "Time elapsed: 9.00 min\n",
      "Epoch: 016/100 | Batch 000/192 | Cost: 0.6320\n",
      "Epoch: 016/100 | Batch 120/192 | Cost: 0.7301\n",
      "Epoch: 016/100 Train Acc.: 78.47% | Validation Acc.: 74.30%\n",
      "Time elapsed: 9.59 min\n",
      "Epoch: 017/100 | Batch 000/192 | Cost: 0.6161\n",
      "Epoch: 017/100 | Batch 120/192 | Cost: 0.5296\n",
      "Epoch: 017/100 Train Acc.: 75.61% | Validation Acc.: 71.90%\n",
      "Time elapsed: 10.19 min\n",
      "Epoch: 018/100 | Batch 000/192 | Cost: 0.7155\n",
      "Epoch: 018/100 | Batch 120/192 | Cost: 0.6337\n",
      "Epoch: 018/100 Train Acc.: 78.82% | Validation Acc.: 73.20%\n",
      "Time elapsed: 10.79 min\n",
      "Epoch: 019/100 | Batch 000/192 | Cost: 0.5872\n",
      "Epoch: 019/100 | Batch 120/192 | Cost: 0.7798\n",
      "Epoch: 019/100 Train Acc.: 77.89% | Validation Acc.: 71.50%\n",
      "Time elapsed: 11.39 min\n",
      "Epoch: 020/100 | Batch 000/192 | Cost: 0.6479\n",
      "Epoch: 020/100 | Batch 120/192 | Cost: 0.7479\n",
      "Epoch: 020/100 Train Acc.: 77.54% | Validation Acc.: 73.90%\n",
      "Time elapsed: 11.99 min\n",
      "Epoch: 021/100 | Batch 000/192 | Cost: 0.6184\n",
      "Epoch: 021/100 | Batch 120/192 | Cost: 0.5969\n",
      "Epoch: 021/100 Train Acc.: 78.09% | Validation Acc.: 73.10%\n",
      "Time elapsed: 12.59 min\n",
      "Epoch: 022/100 | Batch 000/192 | Cost: 0.6987\n",
      "Epoch: 022/100 | Batch 120/192 | Cost: 0.6107\n",
      "Epoch: 022/100 Train Acc.: 79.20% | Validation Acc.: 75.40%\n",
      "Time elapsed: 13.19 min\n",
      "Epoch: 023/100 | Batch 000/192 | Cost: 0.6745\n",
      "Epoch: 023/100 | Batch 120/192 | Cost: 0.6630\n",
      "Epoch: 023/100 Train Acc.: 77.09% | Validation Acc.: 73.50%\n",
      "Time elapsed: 13.79 min\n",
      "Epoch: 024/100 | Batch 000/192 | Cost: 0.5587\n",
      "Epoch: 024/100 | Batch 120/192 | Cost: 0.7128\n",
      "Epoch: 024/100 Train Acc.: 78.96% | Validation Acc.: 75.10%\n",
      "Time elapsed: 14.39 min\n",
      "Epoch: 025/100 | Batch 000/192 | Cost: 0.5177\n",
      "Epoch: 025/100 | Batch 120/192 | Cost: 0.5650\n",
      "Epoch: 025/100 Train Acc.: 78.93% | Validation Acc.: 74.90%\n",
      "Time elapsed: 14.99 min\n",
      "Epoch: 026/100 | Batch 000/192 | Cost: 0.5755\n",
      "Epoch: 026/100 | Batch 120/192 | Cost: 0.5524\n",
      "Epoch: 026/100 Train Acc.: 78.21% | Validation Acc.: 72.10%\n",
      "Time elapsed: 15.59 min\n",
      "Epoch: 027/100 | Batch 000/192 | Cost: 0.6927\n",
      "Epoch: 027/100 | Batch 120/192 | Cost: 0.6578\n",
      "Epoch: 027/100 Train Acc.: 77.63% | Validation Acc.: 74.20%\n",
      "Time elapsed: 16.19 min\n",
      "Epoch: 028/100 | Batch 000/192 | Cost: 0.5719\n",
      "Epoch: 028/100 | Batch 120/192 | Cost: 0.5817\n",
      "Epoch: 028/100 Train Acc.: 79.22% | Validation Acc.: 75.80%\n",
      "Time elapsed: 16.78 min\n",
      "Epoch: 029/100 | Batch 000/192 | Cost: 0.7459\n",
      "Epoch: 029/100 | Batch 120/192 | Cost: 0.5700\n",
      "Epoch: 029/100 Train Acc.: 78.72% | Validation Acc.: 73.20%\n",
      "Time elapsed: 17.38 min\n",
      "Epoch: 030/100 | Batch 000/192 | Cost: 0.6114\n",
      "Epoch: 030/100 | Batch 120/192 | Cost: 0.6975\n",
      "Epoch: 030/100 Train Acc.: 79.26% | Validation Acc.: 74.50%\n",
      "Time elapsed: 17.98 min\n",
      "Epoch: 031/100 | Batch 000/192 | Cost: 0.6240\n",
      "Epoch: 031/100 | Batch 120/192 | Cost: 0.5819\n",
      "Epoch: 031/100 Train Acc.: 78.59% | Validation Acc.: 73.50%\n",
      "Time elapsed: 18.58 min\n",
      "Epoch: 032/100 | Batch 000/192 | Cost: 0.6015\n",
      "Epoch: 032/100 | Batch 120/192 | Cost: 0.6598\n",
      "Epoch: 032/100 Train Acc.: 79.42% | Validation Acc.: 74.90%\n",
      "Time elapsed: 19.18 min\n",
      "Epoch: 033/100 | Batch 000/192 | Cost: 0.5757\n",
      "Epoch: 033/100 | Batch 120/192 | Cost: 0.5761\n",
      "Epoch: 033/100 Train Acc.: 79.55% | Validation Acc.: 73.30%\n",
      "Time elapsed: 19.78 min\n",
      "Epoch: 034/100 | Batch 000/192 | Cost: 0.5700\n",
      "Epoch: 034/100 | Batch 120/192 | Cost: 0.5877\n",
      "Epoch: 034/100 Train Acc.: 78.55% | Validation Acc.: 74.30%\n",
      "Time elapsed: 20.38 min\n",
      "Epoch: 035/100 | Batch 000/192 | Cost: 0.7023\n",
      "Epoch: 035/100 | Batch 120/192 | Cost: 0.6592\n",
      "Epoch: 035/100 Train Acc.: 79.70% | Validation Acc.: 74.50%\n",
      "Time elapsed: 20.98 min\n",
      "Epoch: 036/100 | Batch 000/192 | Cost: 0.6325\n",
      "Epoch: 036/100 | Batch 120/192 | Cost: 0.5737\n",
      "Epoch: 036/100 Train Acc.: 78.72% | Validation Acc.: 74.00%\n",
      "Time elapsed: 21.58 min\n",
      "Epoch: 037/100 | Batch 000/192 | Cost: 0.6329\n",
      "Epoch: 037/100 | Batch 120/192 | Cost: 0.6539\n",
      "Epoch: 037/100 Train Acc.: 79.14% | Validation Acc.: 75.80%\n",
      "Time elapsed: 22.18 min\n",
      "Epoch: 038/100 | Batch 000/192 | Cost: 0.6387\n",
      "Epoch: 038/100 | Batch 120/192 | Cost: 0.7099\n",
      "Epoch: 038/100 Train Acc.: 79.74% | Validation Acc.: 75.80%\n",
      "Time elapsed: 22.78 min\n",
      "Epoch: 039/100 | Batch 000/192 | Cost: 0.5981\n",
      "Epoch: 039/100 | Batch 120/192 | Cost: 0.6197\n",
      "Epoch: 039/100 Train Acc.: 79.78% | Validation Acc.: 75.20%\n",
      "Time elapsed: 23.38 min\n",
      "Epoch: 040/100 | Batch 000/192 | Cost: 0.5893\n",
      "Epoch: 040/100 | Batch 120/192 | Cost: 0.6424\n",
      "Epoch: 040/100 Train Acc.: 79.92% | Validation Acc.: 75.40%\n",
      "Time elapsed: 23.97 min\n",
      "Epoch: 041/100 | Batch 000/192 | Cost: 0.5046\n",
      "Epoch: 041/100 | Batch 120/192 | Cost: 0.6603\n",
      "Epoch: 041/100 Train Acc.: 80.10% | Validation Acc.: 76.90%\n",
      "Time elapsed: 24.57 min\n",
      "Epoch: 042/100 | Batch 000/192 | Cost: 0.5880\n",
      "Epoch: 042/100 | Batch 120/192 | Cost: 0.6180\n",
      "Epoch: 042/100 Train Acc.: 80.03% | Validation Acc.: 75.80%\n",
      "Time elapsed: 25.17 min\n",
      "Epoch: 043/100 | Batch 000/192 | Cost: 0.5328\n",
      "Epoch: 043/100 | Batch 120/192 | Cost: 0.5885\n",
      "Epoch: 043/100 Train Acc.: 80.61% | Validation Acc.: 75.30%\n",
      "Time elapsed: 25.77 min\n",
      "Epoch: 044/100 | Batch 000/192 | Cost: 0.5354\n",
      "Epoch: 044/100 | Batch 120/192 | Cost: 0.6034\n",
      "Epoch: 044/100 Train Acc.: 79.90% | Validation Acc.: 75.40%\n",
      "Time elapsed: 26.37 min\n",
      "Epoch: 045/100 | Batch 000/192 | Cost: 0.5417\n",
      "Epoch: 045/100 | Batch 120/192 | Cost: 0.5952\n",
      "Epoch: 045/100 Train Acc.: 80.90% | Validation Acc.: 76.70%\n",
      "Time elapsed: 26.97 min\n",
      "Epoch: 046/100 | Batch 000/192 | Cost: 0.5907\n",
      "Epoch: 046/100 | Batch 120/192 | Cost: 0.5139\n",
      "Epoch: 046/100 Train Acc.: 80.51% | Validation Acc.: 75.30%\n",
      "Time elapsed: 27.57 min\n",
      "Epoch: 047/100 | Batch 000/192 | Cost: 0.5880\n",
      "Epoch: 047/100 | Batch 120/192 | Cost: 0.6240\n",
      "Epoch: 047/100 Train Acc.: 79.33% | Validation Acc.: 75.70%\n",
      "Time elapsed: 28.17 min\n",
      "Epoch: 048/100 | Batch 000/192 | Cost: 0.5946\n",
      "Epoch: 048/100 | Batch 120/192 | Cost: 0.4510\n",
      "Epoch: 048/100 Train Acc.: 78.83% | Validation Acc.: 74.50%\n",
      "Time elapsed: 28.76 min\n",
      "Epoch: 049/100 | Batch 000/192 | Cost: 0.4876\n",
      "Epoch: 049/100 | Batch 120/192 | Cost: 0.5322\n",
      "Epoch: 049/100 Train Acc.: 79.87% | Validation Acc.: 73.60%\n",
      "Time elapsed: 29.37 min\n",
      "Epoch: 050/100 | Batch 000/192 | Cost: 0.5297\n",
      "Epoch: 050/100 | Batch 120/192 | Cost: 0.5505\n",
      "Epoch: 050/100 Train Acc.: 80.06% | Validation Acc.: 74.70%\n",
      "Time elapsed: 29.96 min\n",
      "Epoch: 051/100 | Batch 000/192 | Cost: 0.5628\n",
      "Epoch: 051/100 | Batch 120/192 | Cost: 0.6086\n",
      "Epoch: 051/100 Train Acc.: 79.69% | Validation Acc.: 74.50%\n",
      "Time elapsed: 30.56 min\n",
      "Epoch: 052/100 | Batch 000/192 | Cost: 0.5363\n",
      "Epoch: 052/100 | Batch 120/192 | Cost: 0.5686\n",
      "Epoch: 052/100 Train Acc.: 79.82% | Validation Acc.: 75.40%\n",
      "Time elapsed: 31.17 min\n",
      "Epoch: 053/100 | Batch 000/192 | Cost: 0.5480\n",
      "Epoch: 053/100 | Batch 120/192 | Cost: 0.6160\n",
      "Epoch: 053/100 Train Acc.: 81.01% | Validation Acc.: 74.90%\n",
      "Time elapsed: 31.77 min\n",
      "Epoch: 054/100 | Batch 000/192 | Cost: 0.6155\n",
      "Epoch: 054/100 | Batch 120/192 | Cost: 0.5478\n",
      "Epoch: 054/100 Train Acc.: 80.68% | Validation Acc.: 76.30%\n",
      "Time elapsed: 32.37 min\n",
      "Epoch: 055/100 | Batch 000/192 | Cost: 0.5578\n",
      "Epoch: 055/100 | Batch 120/192 | Cost: 0.5159\n",
      "Epoch: 055/100 Train Acc.: 81.08% | Validation Acc.: 76.00%\n",
      "Time elapsed: 32.97 min\n",
      "Epoch: 056/100 | Batch 000/192 | Cost: 0.5334\n",
      "Epoch: 056/100 | Batch 120/192 | Cost: 0.4940\n",
      "Epoch: 056/100 Train Acc.: 80.93% | Validation Acc.: 74.80%\n",
      "Time elapsed: 33.56 min\n",
      "Epoch: 057/100 | Batch 000/192 | Cost: 0.4510\n",
      "Epoch: 057/100 | Batch 120/192 | Cost: 0.4787\n",
      "Epoch: 057/100 Train Acc.: 80.99% | Validation Acc.: 77.60%\n",
      "Time elapsed: 34.16 min\n",
      "Epoch: 058/100 | Batch 000/192 | Cost: 0.5937\n",
      "Epoch: 058/100 | Batch 120/192 | Cost: 0.5977\n",
      "Epoch: 058/100 Train Acc.: 81.55% | Validation Acc.: 76.60%\n",
      "Time elapsed: 34.76 min\n",
      "Epoch: 059/100 | Batch 000/192 | Cost: 0.5205\n",
      "Epoch: 059/100 | Batch 120/192 | Cost: 0.5913\n",
      "Epoch: 059/100 Train Acc.: 81.43% | Validation Acc.: 75.50%\n",
      "Time elapsed: 35.36 min\n",
      "Epoch: 060/100 | Batch 000/192 | Cost: 0.5559\n",
      "Epoch: 060/100 | Batch 120/192 | Cost: 0.5472\n",
      "Epoch: 060/100 Train Acc.: 81.28% | Validation Acc.: 76.00%\n",
      "Time elapsed: 35.96 min\n",
      "Epoch: 061/100 | Batch 000/192 | Cost: 0.5187\n",
      "Epoch: 061/100 | Batch 120/192 | Cost: 0.5881\n",
      "Epoch: 061/100 Train Acc.: 80.71% | Validation Acc.: 75.30%\n",
      "Time elapsed: 36.56 min\n",
      "Epoch: 062/100 | Batch 000/192 | Cost: 0.5069\n",
      "Epoch: 062/100 | Batch 120/192 | Cost: 0.4619\n",
      "Epoch: 062/100 Train Acc.: 79.42% | Validation Acc.: 74.50%\n",
      "Time elapsed: 37.16 min\n",
      "Epoch: 063/100 | Batch 000/192 | Cost: 0.5226\n",
      "Epoch: 063/100 | Batch 120/192 | Cost: 0.4582\n",
      "Epoch: 063/100 Train Acc.: 82.28% | Validation Acc.: 76.50%\n",
      "Time elapsed: 37.76 min\n",
      "Epoch: 064/100 | Batch 000/192 | Cost: 0.4503\n",
      "Epoch: 064/100 | Batch 120/192 | Cost: 0.6131\n",
      "Epoch: 064/100 Train Acc.: 81.77% | Validation Acc.: 77.10%\n",
      "Time elapsed: 38.36 min\n",
      "Epoch: 065/100 | Batch 000/192 | Cost: 0.5485\n",
      "Epoch: 065/100 | Batch 120/192 | Cost: 0.4611\n",
      "Epoch: 065/100 Train Acc.: 82.12% | Validation Acc.: 76.90%\n",
      "Time elapsed: 38.96 min\n",
      "Epoch: 066/100 | Batch 000/192 | Cost: 0.5312\n",
      "Epoch: 066/100 | Batch 120/192 | Cost: 0.5280\n",
      "Epoch: 066/100 Train Acc.: 80.55% | Validation Acc.: 75.70%\n",
      "Time elapsed: 39.56 min\n",
      "Epoch: 067/100 | Batch 000/192 | Cost: 0.5594\n",
      "Epoch: 067/100 | Batch 120/192 | Cost: 0.5339\n",
      "Epoch: 067/100 Train Acc.: 82.44% | Validation Acc.: 75.90%\n",
      "Time elapsed: 40.16 min\n",
      "Epoch: 068/100 | Batch 000/192 | Cost: 0.5033\n",
      "Epoch: 068/100 | Batch 120/192 | Cost: 0.5830\n",
      "Epoch: 068/100 Train Acc.: 80.29% | Validation Acc.: 77.00%\n",
      "Time elapsed: 40.76 min\n",
      "Epoch: 069/100 | Batch 000/192 | Cost: 0.5595\n",
      "Epoch: 069/100 | Batch 120/192 | Cost: 0.4824\n",
      "Epoch: 069/100 Train Acc.: 82.34% | Validation Acc.: 76.60%\n",
      "Time elapsed: 41.36 min\n",
      "Epoch: 070/100 | Batch 000/192 | Cost: 0.4599\n",
      "Epoch: 070/100 | Batch 120/192 | Cost: 0.5537\n",
      "Epoch: 070/100 Train Acc.: 80.60% | Validation Acc.: 76.40%\n",
      "Time elapsed: 41.96 min\n",
      "Epoch: 071/100 | Batch 000/192 | Cost: 0.4914\n",
      "Epoch: 071/100 | Batch 120/192 | Cost: 0.4687\n",
      "Epoch: 071/100 Train Acc.: 81.68% | Validation Acc.: 78.70%\n",
      "Time elapsed: 42.56 min\n",
      "Epoch: 072/100 | Batch 000/192 | Cost: 0.5055\n",
      "Epoch: 072/100 | Batch 120/192 | Cost: 0.5267\n",
      "Epoch: 072/100 Train Acc.: 82.67% | Validation Acc.: 78.50%\n",
      "Time elapsed: 43.16 min\n",
      "Epoch: 073/100 | Batch 000/192 | Cost: 0.4431\n",
      "Epoch: 073/100 | Batch 120/192 | Cost: 0.5050\n",
      "Epoch: 073/100 Train Acc.: 80.47% | Validation Acc.: 75.50%\n",
      "Time elapsed: 43.75 min\n",
      "Epoch: 074/100 | Batch 000/192 | Cost: 0.6785\n",
      "Epoch: 074/100 | Batch 120/192 | Cost: 0.4457\n",
      "Epoch: 074/100 Train Acc.: 83.49% | Validation Acc.: 78.10%\n",
      "Time elapsed: 44.35 min\n",
      "Epoch: 075/100 | Batch 000/192 | Cost: 0.4841\n",
      "Epoch: 075/100 | Batch 120/192 | Cost: 0.5260\n",
      "Epoch: 075/100 Train Acc.: 82.56% | Validation Acc.: 76.60%\n",
      "Time elapsed: 44.95 min\n",
      "Epoch: 076/100 | Batch 000/192 | Cost: 0.4382\n",
      "Epoch: 076/100 | Batch 120/192 | Cost: 0.5470\n",
      "Epoch: 076/100 Train Acc.: 80.99% | Validation Acc.: 76.40%\n",
      "Time elapsed: 45.55 min\n",
      "Epoch: 077/100 | Batch 000/192 | Cost: 0.5573\n",
      "Epoch: 077/100 | Batch 120/192 | Cost: 0.5162\n",
      "Epoch: 077/100 Train Acc.: 81.44% | Validation Acc.: 75.30%\n",
      "Time elapsed: 46.15 min\n",
      "Epoch: 078/100 | Batch 000/192 | Cost: 0.6098\n",
      "Epoch: 078/100 | Batch 120/192 | Cost: 0.5280\n",
      "Epoch: 078/100 Train Acc.: 81.36% | Validation Acc.: 77.40%\n",
      "Time elapsed: 46.75 min\n",
      "Epoch: 079/100 | Batch 000/192 | Cost: 0.5197\n",
      "Epoch: 079/100 | Batch 120/192 | Cost: 0.5272\n",
      "Epoch: 079/100 Train Acc.: 83.33% | Validation Acc.: 79.00%\n",
      "Time elapsed: 47.35 min\n",
      "Epoch: 080/100 | Batch 000/192 | Cost: 0.4899\n",
      "Epoch: 080/100 | Batch 120/192 | Cost: 0.4799\n",
      "Epoch: 080/100 Train Acc.: 82.94% | Validation Acc.: 77.20%\n",
      "Time elapsed: 47.95 min\n",
      "Epoch: 081/100 | Batch 000/192 | Cost: 0.5248\n",
      "Epoch: 081/100 | Batch 120/192 | Cost: 0.5740\n",
      "Epoch: 081/100 Train Acc.: 82.80% | Validation Acc.: 76.40%\n",
      "Time elapsed: 48.54 min\n",
      "Epoch: 082/100 | Batch 000/192 | Cost: 0.5918\n",
      "Epoch: 082/100 | Batch 120/192 | Cost: 0.5965\n",
      "Epoch: 082/100 Train Acc.: 82.79% | Validation Acc.: 79.90%\n",
      "Time elapsed: 49.14 min\n",
      "Epoch: 083/100 | Batch 000/192 | Cost: 0.4830\n",
      "Epoch: 083/100 | Batch 120/192 | Cost: 0.5613\n",
      "Epoch: 083/100 Train Acc.: 82.44% | Validation Acc.: 77.20%\n",
      "Time elapsed: 49.74 min\n",
      "Epoch: 084/100 | Batch 000/192 | Cost: 0.4969\n",
      "Epoch: 084/100 | Batch 120/192 | Cost: 0.4625\n",
      "Epoch: 084/100 Train Acc.: 82.37% | Validation Acc.: 77.20%\n",
      "Time elapsed: 50.34 min\n",
      "Epoch: 085/100 | Batch 000/192 | Cost: 0.5664\n",
      "Epoch: 085/100 | Batch 120/192 | Cost: 0.4476\n",
      "Epoch: 085/100 Train Acc.: 83.63% | Validation Acc.: 78.10%\n",
      "Time elapsed: 50.94 min\n",
      "Epoch: 086/100 | Batch 000/192 | Cost: 0.4885\n",
      "Epoch: 086/100 | Batch 120/192 | Cost: 0.5724\n",
      "Epoch: 086/100 Train Acc.: 82.88% | Validation Acc.: 77.10%\n",
      "Time elapsed: 51.54 min\n",
      "Epoch: 087/100 | Batch 000/192 | Cost: 0.4967\n",
      "Epoch: 087/100 | Batch 120/192 | Cost: 0.3785\n",
      "Epoch: 087/100 Train Acc.: 83.54% | Validation Acc.: 77.70%\n",
      "Time elapsed: 52.14 min\n",
      "Epoch: 088/100 | Batch 000/192 | Cost: 0.3773\n",
      "Epoch: 088/100 | Batch 120/192 | Cost: 0.5794\n",
      "Epoch: 088/100 Train Acc.: 83.13% | Validation Acc.: 77.00%\n",
      "Time elapsed: 52.74 min\n",
      "Epoch: 089/100 | Batch 000/192 | Cost: 0.4332\n",
      "Epoch: 089/100 | Batch 120/192 | Cost: 0.5796\n",
      "Epoch: 089/100 Train Acc.: 83.47% | Validation Acc.: 78.10%\n",
      "Time elapsed: 53.34 min\n",
      "Epoch: 090/100 | Batch 000/192 | Cost: 0.4397\n",
      "Epoch: 090/100 | Batch 120/192 | Cost: 0.4944\n",
      "Epoch: 090/100 Train Acc.: 83.38% | Validation Acc.: 78.10%\n",
      "Time elapsed: 53.94 min\n",
      "Epoch: 091/100 | Batch 000/192 | Cost: 0.3981\n",
      "Epoch: 091/100 | Batch 120/192 | Cost: 0.5605\n",
      "Epoch: 091/100 Train Acc.: 81.94% | Validation Acc.: 77.50%\n",
      "Time elapsed: 54.54 min\n",
      "Epoch: 092/100 | Batch 000/192 | Cost: 0.4216\n",
      "Epoch: 092/100 | Batch 120/192 | Cost: 0.4957\n",
      "Epoch: 092/100 Train Acc.: 83.90% | Validation Acc.: 78.60%\n",
      "Time elapsed: 55.14 min\n",
      "Epoch: 093/100 | Batch 000/192 | Cost: 0.4642\n",
      "Epoch: 093/100 | Batch 120/192 | Cost: 0.4166\n",
      "Epoch: 093/100 Train Acc.: 80.66% | Validation Acc.: 76.10%\n",
      "Time elapsed: 55.74 min\n",
      "Epoch: 094/100 | Batch 000/192 | Cost: 0.6470\n",
      "Epoch: 094/100 | Batch 120/192 | Cost: 0.4605\n",
      "Epoch: 094/100 Train Acc.: 82.41% | Validation Acc.: 77.70%\n",
      "Time elapsed: 56.34 min\n",
      "Epoch: 095/100 | Batch 000/192 | Cost: 0.5082\n",
      "Epoch: 095/100 | Batch 120/192 | Cost: 0.5747\n",
      "Epoch: 095/100 Train Acc.: 83.41% | Validation Acc.: 77.90%\n",
      "Time elapsed: 56.94 min\n",
      "Epoch: 096/100 | Batch 000/192 | Cost: 0.5011\n",
      "Epoch: 096/100 | Batch 120/192 | Cost: 0.4896\n",
      "Epoch: 096/100 Train Acc.: 83.62% | Validation Acc.: 77.80%\n",
      "Time elapsed: 57.53 min\n",
      "Epoch: 097/100 | Batch 000/192 | Cost: 0.4947\n",
      "Epoch: 097/100 | Batch 120/192 | Cost: 0.4069\n",
      "Epoch: 097/100 Train Acc.: 82.59% | Validation Acc.: 77.80%\n",
      "Time elapsed: 58.13 min\n",
      "Epoch: 098/100 | Batch 000/192 | Cost: 0.5170\n",
      "Epoch: 098/100 | Batch 120/192 | Cost: 0.5355\n",
      "Epoch: 098/100 Train Acc.: 84.31% | Validation Acc.: 80.50%\n",
      "Time elapsed: 58.73 min\n",
      "Epoch: 099/100 | Batch 000/192 | Cost: 0.3841\n",
      "Epoch: 099/100 | Batch 120/192 | Cost: 0.4383\n",
      "Epoch: 099/100 Train Acc.: 84.64% | Validation Acc.: 78.80%\n",
      "Time elapsed: 59.33 min\n",
      "Epoch: 100/100 | Batch 000/192 | Cost: 0.3343\n",
      "Epoch: 100/100 | Batch 120/192 | Cost: 0.5621\n",
      "Epoch: 100/100 Train Acc.: 83.69% | Validation Acc.: 78.20%\n",
      "Time elapsed: 59.93 min\n",
      "Total Training Time: 59.93 min\n"
     ]
    }
   ],
   "source": [
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "    \n",
    "        ### PREPARE MINIBATCH\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 120:\n",
    "            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n",
    "                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |' \n",
    "                   f' Cost: {cost:.4f}')\n",
    "\n",
    "    # no need to build the computation graph for backprop when computing accuracy\n",
    "    with torch.set_grad_enabled(False):\n",
    "        train_acc = compute_accuracy(model, train_loader, device=DEVICE)\n",
    "        valid_acc = compute_accuracy(model, valid_loader, device=DEVICE)\n",
    "        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'\n",
    "              f' | Validation Acc.: {valid_acc:.2f}%')\n",
    "        \n",
    "    elapsed = (time.time() - start_time)/60\n",
    "    print(f'Time elapsed: {elapsed:.2f} min')\n",
    "  \n",
    "elapsed = (time.time() - start_time)/60\n",
    "print(f'Total Training Time: {elapsed:.2f} min')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
